Optimization - SGD

Lecture 15

Dr. Colin Rundel

Stochastic Gradient Descent

A regression example

from sklearn.datasets import make_regression
X, y, coef = make_regression(
  n_samples=200, n_features=20, n_informative=4, 
  bias=3, noise=1, random_state=1234, coef=True
)
y
array([ -36.2252,    9.6357,   66.4583,   48.9574,   24.1885,  -13.2444,
         18.1455, -135.047 ,  116.5772,   60.2524,   30.9319,  107.148 ,
         21.6209,   66.2401, -132.8878,   58.636 ,   22.186 ,   60.3852,
        -85.0383,   55.1704,  -31.3817,  -57.0697,   67.3215,    2.878 ,
        -29.5613,  -41.3973,  -30.3048,  -41.5597,   52.7531,  -63.5633,
          3.5671,   63.712 ,    9.9833,   78.4881,  -76.126 ,   13.4331,
        122.6162,   79.0354,   91.2171,   48.7344,  103.6366,   52.5964,
         35.0064,  -65.8423,  -47.3045,  -25.6876,    1.8359,   35.4113,
         28.0687,   56.3528,    3.6755,  -72.3309,   57.143 ,  -16.9438,
         54.1445,   72.6828,   -5.0538, -180.6135,  -44.6205,    9.2071,
         -5.5324,  -29.6013,  135.3656,  114.241 ,  -97.4878,   15.0648,
         14.7958,   71.503 ,   -4.6583,  -36.791 ,   -5.3845, -119.8073,
         11.174 ,   36.3008,   82.5499,  -20.0869,   14.7146,  -59.0765,
         39.4171,   48.4013,  -61.9613,   -5.6247,  103.2374,   41.2613,
       -129.2273,   10.5113,   32.4936,   78.6921,    5.2956,   64.4473,
         88.8358,   39.4851,  -11.4866,  -52.5082,  112.7248,   -9.7006,
         13.8393,  -36.4004,   68.4865,   19.5335,  -75.447 ,  -87.9538,
         79.4784,  -75.094 ,   25.6229,   84.9034,   71.2779,  -66.4093,
         77.6444,   40.8875,   31.3165,  -22.7143,   84.562 ,    6.8075,
          9.778 ,  -65.9149,  106.6952,   -3.1901,   41.1555,   32.6265,
        -36.5738,   38.9966,  -78.664 ,  -56.0434,    2.9191,   42.6286,
         51.3644,  -21.8072,  -21.9779,  -15.7102,  -23.5586,    1.3801,
         20.4269,   55.7188,  -45.6388,  -55.1542,   74.6067,   -7.2716,
        -31.1045,   48.1571,   14.7487,   41.6956,  -59.6062,  -33.0811,
         81.0177,   -9.4896,  164.1317,   25.3507,    6.0141,   46.3718,
         84.2983,  -63.2593,  -17.4733,  -26.2977,  -56.4681,   17.003 ,
         53.1867,  -94.5398,  -18.2541,  -49.343 ,   40.8724,  -90.5986,
         27.9392,   41.7287,   49.8082,   -9.6384,  -66.7551,  122.9159,
        -41.3566,  -98.6863,  -45.0718,    9.9327,  -22.0927,   10.6199,
        -12.2831,    7.4184,   57.6091,  -27.3456,  -36.4045,  -51.659 ,
         28.8175,  -23.9402,  -51.0637,    4.3618,   10.8402,  -11.087 ,
        -29.9801,  113.6633,   66.5601,    1.3808,  -19.4875,   40.812 ,
         43.0652,   35.4802,   77.0732,  -49.7352,   65.7192,   73.8539,
        -59.4116,   72.9501])
X
array([[-0.6465,  2.0803,  0.1412, -0.8419, -0.1595,  1.3321, -0.4262,
        -0.0351, -0.1938, -0.6093, -0.3433,  0.6126,  0.3777, -1.2062,
        -0.2277, -0.8896, -0.4674, -1.3566,  1.4989, -0.7468],
       [-0.3834, -0.3631, -1.2196,  0.6   ,  0.3315,  1.1056,  0.2662,
        -0.7239,  0.0259, -0.2172, -0.6841,  0.0991,  0.2794, -1.208 ,
        -0.7818, -1.7348, -1.3397, -0.5723, -0.5882,  0.2717],
       [-0.1637, -0.8118,  0.9551,  0.5711,  0.8719, -0.9619,  1.9846,
        -1.1806, -1.1261,  0.297 ,  1.2499,  0.7109, -0.1183,  0.6708,
         0.6895,  1.4705,  0.0634, -0.3079, -2.2512, -0.0216],
       [-0.9292, -0.4897, -2.1196, -1.142 ,  1.266 , -0.2988,  1.0016,
        -2.1969, -1.0739, -0.1149,  0.5122,  0.302 , -0.0974,  1.3461,
         0.1909,  1.1223,  0.6268,  2.2035, -0.5135,  2.0118],
       [ 0.1645, -0.5847,  0.2708, -3.5635,  0.1526,  0.5283,  0.7674,
         1.392 , -0.0819,  1.3211,  0.4644, -1.0279,  0.9849, -1.069 ,
        -0.4301,  0.0798, -0.5119, -0.3448,  0.8166, -0.4   ],
       [ 0.4134,  1.9511, -0.5013, -1.4894,  0.4191, -1.4104,  0.2617,
        -0.6981,  0.0368, -1.151 ,  2.0752,  0.5001, -0.2428,  0.45  ,
         0.7176,  1.3846,  0.5155,  0.4459, -0.2784, -0.2864],
       [-0.0628, -1.424 , -1.1023,  0.1445, -0.4836,  1.4795, -0.5921,
         1.6423, -0.5013,  0.4435,  2.0044,  0.6221,  0.0747, -1.4117,
        -0.202 , -1.3071, -0.8656, -1.311 ,  0.0424,  0.7255],
       [-0.6642,  1.4317, -0.0658, -0.7379, -0.9153,  0.8653,  0.7143,
         1.0912, -1.3773, -2.6022, -0.2955, -0.3985,  0.0918,  0.3851,
         0.502 , -0.4665,  1.6432, -0.2438, -0.4943,  1.4753],
       [ 1.5247, -1.3419, -0.4453, -0.6141,  2.0632, -1.0742, -1.4419,
        -1.4923,  0.3135,  0.7691,  0.5383, -0.9741,  0.8457, -0.0014,
         0.3895,  0.2118,  1.0977, -0.4036, -1.5496, -0.3672],
       [ 1.5524, -1.1109, -0.5624, -0.9106, -0.0506,  0.8533, -0.5452,
        -1.7836, -0.8365,  2.171 , -0.6158,  0.2523, -1.8707,  0.6142,
         0.7962,  0.0706, -0.2386, -0.4144, -0.0898, -0.4745],
       [-0.7206, -2.0213,  0.0157, -1.191 , -0.3127,  0.2891,  0.8596,
        -2.2427,  0.0021,  1.4327,  0.4714,  0.9533, -0.6365,  1.3212,
         0.8872,  1.15  , -1.5469,  0.4055, -0.3341,  0.9919],
       [ 0.2328,  0.5523, -0.2356, -1.2547,  0.6686, -2.1204, -0.186 ,
         1.4915, -1.1353,  2.3889,  0.3449, -0.6703, -0.2358, -2.1923,
        -0.4635, -0.9962, -0.1116,  0.0605,  0.0027, -1.439 ],
       [ 1.8092, -1.5857, -0.9765,  1.6171,  0.368 , -0.2947,  1.5897,
        -0.8878,  1.0547, -0.0427, -1.187 ,  0.7605,  1.2381, -0.5014,
         1.0201, -0.5773, -0.632 , -0.502 , -1.6914,  0.803 ],
       [ 0.1935,  0.5289, -0.7559, -0.1047, -0.3334, -1.0275,  1.0327,
        -0.8811,  0.0483,  1.8504,  1.5727,  0.3325, -1.7398, -0.2383,
        -0.4967,  0.3939,  1.9322,  0.062 , -1.1205, -0.95  ],
       [ 0.6914, -2.0308,  1.1554, -0.4219, -1.6257, -0.1138, -0.9225,
        -1.9216,  1.2995, -1.5084, -0.863 ,  0.2528,  1.3636,  0.2059,
         0.0381,  1.1124,  1.73  ,  0.4496, -0.1806,  0.7681],
       [-0.0862, -0.2131, -0.5343, -0.1066, -0.8403,  1.3862,  0.5885,
        -1.089 , -0.8571,  2.0178,  2.6078, -0.5807, -0.3466, -0.5166,
        -0.7863,  0.2918, -0.1904, -0.8012, -1.6868,  0.2538],
       [-1.0714, -0.4582,  0.4255,  0.5657, -0.1743,  2.0978, -0.8453,
        -0.9807, -0.0414,  0.5851,  0.2645, -0.3602,  0.4151,  1.2829,
        -0.0485, -0.4278,  0.2703,  0.821 , -1.338 ,  1.4986],
       [ 0.2126,  2.1145, -0.1471,  1.7549,  0.9465, -1.3906, -1.0954,
        -0.5224,  0.5338,  0.0591, -0.2671,  1.5731,  0.3903,  0.6137,
        -0.5277,  0.6306,  0.7467,  1.7232,  0.62  ,  2.0249],
       [ 0.7085,  1.312 , -0.6134,  0.8665, -1.4706,  0.2597, -0.1606,
        -0.7118,  0.2154, -0.7415, -0.608 , -0.3412,  1.0772,  0.4695,
        -0.1285,  0.0654,  0.4922, -0.6707, -1.8229, -0.4215],
       [ 1.2418, -1.9068, -0.6066,  0.1639,  0.986 , -0.1853, -0.0303,
         1.152 , -0.161 ,  0.0226,  0.8991,  0.9874, -0.802 , -0.7241,
         0.2466,  0.747 , -0.9682, -1.1908,  0.4313, -1.2039],
       [-0.6263,  0.2757,  0.9388,  1.3835, -0.5935,  0.4409, -1.4681,
         0.0114, -0.3643, -0.3373, -1.3341,  0.0036,  0.5513, -0.1016,
         0.6814, -1.4258, -1.3869, -2.0679, -1.6482,  1.0062],
       [-0.2397,  0.6481,  1.7758,  0.0166, -1.7724, -0.1862,  1.118 ,
        -0.8409,  0.6136,  0.5269, -0.2908, -0.2294,  0.1747, -0.3881,
        -0.2667, -0.7601,  0.4313, -0.7488, -0.7594, -0.4084],
       [-1.0989,  1.1887, -0.5288,  1.6782,  0.3827,  0.4309, -1.3949,
         0.6801, -1.2572,  0.6585,  0.7674, -1.5397,  1.1786,  1.2429,
        -1.1094,  1.2524, -0.7556,  0.4051, -0.3198,  0.6704],
       [ 0.0582,  0.1247,  0.1058, -0.4947, -0.1381,  1.3226,  0.3375,
         0.0445,  1.2923,  0.67  , -1.3132, -0.7997, -0.1669,  1.5938,
        -0.7805, -0.3689, -2.5977, -1.2921, -1.2897, -0.074 ],
       [-0.9515, -1.0973,  1.5675,  0.0103, -1.1347,  0.165 ,  0.0289,
        -0.6242, -1.3193,  0.2246,  0.7557, -0.9032,  2.1041, -0.6316,
        -0.1271, -0.4006, -0.8671, -0.5601, -0.0713, -1.1371],
       [ 0.4599,  0.5513,  1.6362, -1.2392, -0.3352,  1.0237,  1.7626,
        -0.5441,  1.3217, -1.2237,  2.5112, -1.7501, -0.0857,  0.8239,
        -0.6406, -1.05  , -0.635 , -2.1445,  1.4129,  0.2546],
       [ 0.1871, -2.2206,  1.2475,  1.2345, -1.5021,  1.1434, -1.0406,
         0.0709,  1.2826,  0.5946, -0.176 ,  0.0639, -1.4364, -0.3326,
        -0.4648,  0.0733, -1.5075,  0.7799, -0.6549,  0.2562],
       [ 0.084 ,  0.9564,  0.366 , -0.6843, -0.6239,  0.3233, -0.4753,
        -0.7024, -0.8606, -0.8089,  1.7968, -0.9079, -0.1103, -0.8212,
         1.328 , -1.2039, -2.1219, -0.8672, -1.345 , -1.0769],
       [-0.807 , -0.037 ,  0.7597,  0.9556,  0.1334, -0.0225,  1.9088,
        -0.423 ,  0.267 ,  0.7138,  0.996 ,  0.0679,  0.1559,  0.1314,
        -0.342 ,  0.1817,  0.4344,  1.383 , -0.1708,  0.2745],
       [-0.3675, -0.93  ,  1.2117,  1.0203, -1.1554, -0.0461,  0.827 ,
         1.793 , -1.0029, -0.7901,  0.0797,  0.992 , -0.5725,  1.3592,
         1.2639,  1.3791,  0.021 , -2.5727, -0.2494,  2.0499],
       ...,
       [ 1.6243, -1.2413, -0.4177,  0.2389, -0.2734, -0.6785, -1.0147,
        -0.2772, -1.711 , -1.1543,  0.2933,  1.487 ,  0.7526,  0.6561,
         0.4132,  0.1095,  0.1406, -0.6598,  1.2687,  1.2148],
       [ 0.6854, -0.7399, -1.0681,  0.2991,  0.0382, -0.9321, -0.8341,
         0.0215,  0.0612, -0.129 ,  0.8795, -0.1681,  0.8851,  1.2921,
         0.3478,  1.5717,  2.4181, -0.0638,  1.3938,  0.884 ],
       [ 2.2326, -1.7645,  1.9779, -1.6875, -0.8401,  0.1057,  1.1688,
         0.3301, -0.5216,  1.207 , -1.5042,  1.6341, -1.0896, -0.7015,
        -1.7587,  1.4814,  0.6081, -0.7485,  2.1342, -0.4016],
       [ 2.2171, -0.6177,  0.1949, -1.0798,  0.586 , -0.859 ,  2.5508,
        -0.8039,  0.1503, -0.1069, -0.6496, -0.2479,  0.1649,  0.765 ,
         0.8986, -0.3648,  0.6722, -0.2408,  0.7112,  0.1551],
       [ 1.1797,  2.0229,  0.2965, -0.4986,  0.6617,  0.8841, -0.8252,
        -0.3799, -1.1173, -1.3918,  0.9206, -0.076 , -0.8812, -1.7954,
         0.2918,  0.7677,  0.4183,  1.236 , -0.1036, -0.0952],
       [-1.4325,  1.5241,  0.4914,  0.4466, -1.5217, -0.5697, -0.1623,
        -0.0357, -1.3161,  1.733 ,  0.7872,  1.4468,  1.8372,  1.0749,
        -2.0308, -0.2996, -1.1323,  0.6271,  0.7217,  0.8836],
       [ 0.4932, -1.2439,  0.5748, -0.1409,  1.7359, -1.1483, -0.4902,
        -0.5052, -0.4267, -0.6533,  0.242 ,  0.7283,  0.2963, -1.2347,
         0.6998,  0.7025, -0.5894, -2.7557, -1.1078, -0.5546],
       [ 0.7622, -1.2367, -0.9891,  1.7989,  0.1187, -1.8608, -0.559 ,
         0.775 ,  0.0616, -1.6046,  0.2385, -1.5886, -0.1833, -0.7817,
         1.8364, -0.5933, -0.3687,  0.3881,  1.2738,  1.2086],
       [-0.7038, -1.2434,  0.5626,  0.3224,  0.3713,  0.7815,  1.957 ,
         0.4423,  0.6326, -1.956 ,  1.1085,  0.1665,  0.841 ,  2.1472,
         0.5566, -0.2651, -0.9084,  2.0134,  0.3486,  1.2223],
       [-0.8597, -1.2391,  0.9525, -0.7438, -0.9162,  0.1223,  0.6288,
         0.9881, -0.223 , -0.3202,  0.5368, -0.9382,  0.1865, -1.4094,
         0.226 , -0.0726,  1.423 ,  2.1237,  0.1397, -0.5506],
       [-0.0956,  0.2007, -0.3942,  0.812 ,  0.6777, -2.5834, -1.3294,
        -0.6009, -1.0962, -0.359 ,  0.0455, -0.5706,  0.0263, -0.9308,
        -0.0649,  0.6586, -0.0469, -1.0283,  0.3524, -0.3676],
       [-1.4168, -0.0501, -0.5261, -0.3774, -0.9222,  0.253 , -0.2044,
         0.0524,  0.8973,  0.0268,  1.6289,  0.4335, -2.1098, -0.488 ,
         0.8635, -1.7442,  0.6374, -1.3713, -0.6505,  0.0552],
       [-0.1955, -0.6314,  1.0877, -0.9238, -1.2701,  0.3528,  0.9894,
         0.4388, -0.2405,  0.3558, -0.2266,  0.5029,  1.3886, -1.8156,
        -0.4634, -0.9616, -0.9101,  0.5856, -0.7043,  1.2456],
       [ 0.1021, -0.9809,  0.3537,  1.6762, -0.7037, -0.2284, -0.278 ,
        -0.4083,  0.666 ,  0.681 , -0.9055,  1.054 , -0.0522,  0.3645,
         1.1951, -1.8104, -1.5148,  1.0655,  0.3521, -0.9033],
       [ 0.2733, -0.0657, -0.5118, -0.9471, -0.4646,  1.4039,  0.5143,
         0.0839, -1.1759,  0.5828,  1.7003, -0.4077,  0.5941, -0.7209,
        -0.2797,  0.6193,  1.001 ,  1.6007, -0.6754,  0.2706],
       [ 1.0588, -2.1785,  0.2624,  0.1752, -0.0675,  0.0093, -1.8534,
        -1.7246,  1.5724, -0.5295,  0.4144,  0.6353, -0.7018, -1.1284,
        -0.1179,  0.2766, -1.2591,  2.0028,  0.312 ,  1.073 ],
       [ 0.428 , -0.2036, -1.3212,  1.1555, -0.439 , -1.6271, -0.1925,
         0.622 , -1.5984, -0.5565, -0.5401,  0.0358,  0.0133, -1.4583,
         2.0069,  0.1263,  0.4161, -0.9847,  0.7966,  0.486 ],
       [-0.1173,  0.6381, -0.2812, -0.3578,  1.9299,  2.1058,  2.3049,
         1.0974,  0.5727,  0.9601, -0.211 ,  0.6298,  0.7401, -1.7709,
        -1.0439,  0.8751, -1.5677,  0.3412, -0.1276, -0.8195],
       [ 0.2991, -0.7227, -0.4858,  0.2055,  0.8665,  0.5538,  0.3632,
         0.3877,  0.9835,  0.313 ,  1.5294, -0.3187,  1.8937,  0.3538,
         1.0765,  0.0236, -0.2756,  0.0235,  0.1774, -0.6602],
       [ 0.2811,  0.94  ,  0.2862,  0.2612, -0.3128,  1.8086, -0.1915,
         0.5764,  0.5697,  0.8199, -2.1527, -1.3712, -0.6986,  0.3284,
        -0.4288,  0.2109, -0.4925, -0.3614,  0.4904, -2.0268],
       [-0.1443, -1.0635, -1.0286, -0.5993, -0.0703,  1.1204, -0.0037,
        -0.4246, -0.3002, -0.7602,  1.4965,  2.5779,  0.7647,  0.1203,
         0.1931,  0.7629,  1.2026, -0.8532,  0.1837,  0.5151],
       [ 0.4857,  0.4299,  0.1458,  0.7444, -0.0979,  1.0103,  0.6701,
        -0.2955,  0.0965,  1.4683, -1.7762,  0.9793, -0.2762, -0.0483,
         0.0385, -0.5835,  0.6415,  0.0514,  0.3048,  1.6422],
       [ 0.4843,  0.7492,  1.2246,  0.5665,  0.2853, -1.8185, -0.7811,
        -1.2811, -0.1822,  0.5036,  0.2912, -0.4508, -0.468 ,  0.0471,
         1.3635,  0.8755,  0.3948,  0.6807, -0.2039, -1.7107],
       [-1.0209, -0.6634, -0.9163,  1.3832,  0.0976,  1.3228,  0.1802,
         0.2148,  1.6929,  0.4813, -0.2126,  0.7982, -1.1308, -0.6505,
         0.0008,  1.4194, -0.0373,  0.3867,  0.2306, -0.5345],
       [-0.7579,  1.397 ,  1.337 , -1.8439,  2.0066,  1.6138, -1.5925,
        -0.2433,  0.1118,  0.11  , -0.0658,  0.3186,  0.2924, -0.2974,
         1.016 , -0.231 ,  1.639 ,  0.4316, -0.8798, -0.3389],
       [ 0.7059,  1.752 , -0.0646, -0.0381, -0.2057,  0.6298, -0.0253,
        -1.205 , -0.1709, -0.6655, -2.0803,  0.4152, -0.1783,  0.8111,
        -2.6128, -3.8809,  2.1338,  0.7489,  0.485 ,  0.9745],
       [-1.1806, -0.3993, -0.2226, -2.3027,  1.3672, -0.1536, -1.9393,
         0.2231, -0.0869,  0.9348, -0.4875, -0.8112,  0.6287,  0.2276,
        -0.2909,  0.0516,  1.3023,  0.5602,  0.515 ,  0.4992],
       [-0.6427, -0.3821,  0.1691,  0.1282,  0.4937,  0.4471, -0.2522,
         0.8716, -1.5433,  1.6954, -0.9341, -0.5208, -1.0103,  0.9605,
        -1.877 ,  0.8052, -1.7248, -1.3744, -0.4322,  1.1537],
       [ 0.8806, -0.7055, -0.0158,  0.172 , -0.4213,  0.8811,  0.7577,
        -0.3878, -0.6382, -1.365 ,  0.1341,  1.7316, -0.6366, -0.6532,
        -1.4726,  0.8897, -1.32  ,  0.7008, -1.2858,  1.1342],
       [-0.5563, -1.0968, -0.9132, -1.2012,  1.8192, -1.1012,  0.805 ,
         2.0475,  2.122 ,  0.0836, -0.1508, -0.6059,  1.1388, -1.2726,
        -1.0138, -0.8948,  0.8435,  1.1215,  0.2176,  1.0591]],
      shape=(200, 20))
coef
array([ 0.    ,  0.    ,  0.    ,  9.6106, 43.4239,  0.    ,  0.    ,
        0.    ,  0.    , 34.453 ,  9.2929,  0.    ,  0.    ,  0.    ,
        0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ])

Minimalistic GD for LR

def grad_desc_lm(X, y, beta, step, max_step=50):
  """Minimal full-batch gradient descent for linear regression.

  An intercept column of ones is prepended to X and the loss is the
  sum of squared residuals; the gradient is obtained via jax.grad.

  Parameters
  ----------
  X : (n, p) design matrix (without an intercept column)
  y : (n,) response vector
  beta : (p+1,) initial coefficients (intercept first)
  step : learning rate for each gradient step
  max_step : number of gradient steps to take

  Returns
  -------
  dict with the coefficient trajectory ("x"), loss values ("loss"),
  and step counter ("iter"), each including the starting state.
  """
  X = jnp.c_[jnp.ones(X.shape[0]), X]
  loss = lambda b: jnp.sum((y - X @ b)**2)
  grad = jax.grad(loss)

  history = {"x": [beta], "loss": [loss(beta).item()], "iter": [0]}

  for step_num in range(1, max_step + 1):
    beta = beta - step * grad(beta)
    history["x"].append(beta)
    history["loss"].append(loss(beta).item())
    history["iter"].append(step_num)

  return history

Linear regression

lm = LinearRegression().fit(X,y)
np.r_[lm.intercept_, lm.coef_]
array([ 3.0616, -0.0121, -0.0096,  0.096 ,  9.6955, 43.406 ,  0.0253,
        0.0284,  0.0962,  0.1069, 34.4884,  9.3445, -0.0165, -0.0147,
       -0.0396,  0.0969, -0.1057, -0.0943,  0.11  , -0.0096, -0.0875])


gd_lm = grad_desc_lm(
  X, y, np.zeros(X.shape[1]+1), 
  step = 0.001, max_step=25
)
gd_lm["x"][-1]
Array([ 3.0634, -0.0118, -0.0136,  0.097 ,  9.7003, 43.4074,  0.0301,
        0.0292,  0.0976,  0.1007, 34.479 ,  9.3418, -0.0197, -0.0123,
       -0.0427,  0.0878, -0.106 , -0.087 ,  0.1106, -0.018 , -0.0955],      dtype=float64)

A quick analysis

Lets take a quick look at the linear regression loss function and gradient descent and think a bit about its cost(s), we can define the loss function and its gradient as follows:

\[ \begin{aligned} f(\boldsymbol{\beta}) &= (y - \boldsymbol{X} \boldsymbol{\beta})^T (y - \boldsymbol{X} \boldsymbol{\beta}) \\ \\ \nabla f(\boldsymbol{\beta}) &= 2 \boldsymbol{X}^T(\boldsymbol{X}\boldsymbol{\beta} - \boldsymbol{y})\\ %&= %\left[ % \begin{matrix} % 2 \boldsymbol{X}_{\cdot 1}^T(\boldsymbol{X}_{\cdot 1}\boldsymbol{\beta}_1 - \boldsymbol{y}) \\ % 2 \boldsymbol{X}_{\cdot 2}^T(\boldsymbol{X}_{\cdot 2}\boldsymbol{\beta}_2 - \boldsymbol{y}) \\ % \vdots \\ % 2 \boldsymbol{X}_{\cdot k}^T(\boldsymbol{X}_{\cdot k}\boldsymbol{\beta}_k - \boldsymbol{y}) % \end{matrix} %\right] \end{aligned} \]

What are the costs of calculating the loss function and gradient respectively in terms of \(n\) and \(k\)?

  • Calculating the loss costs \({O}(nk)\)

  • Calculating the gradient also costs \({O}(nk)\) — computing \(\boldsymbol{X}\boldsymbol{\beta} - \boldsymbol{y}\) is \(O(nk)\) and the product with \(\boldsymbol{X}^T\) is another \(O(nk)\)

Stochastic Gradient Descent

This is a variant of gradient descent where rather than using all \(n\) data points we randomly sample one at a time and use that single point to make our gradient step.

  • Sampling of observations can be done with or without replacement

  • Will take more steps to converge but each step is now cheaper to compute

  • SGD has slower asymptotic convergence than GD, but is often faster in practice

  • Generally requires the learning rate to shrink as a function of iteration to guarantee convergence

SGD - Linear Regression

def sto_grad_desc_lm(X, y, beta, step, max_step=50, seed=1234, replace=True):
  """Stochastic gradient descent for linear regression.

  Each epoch visits n single observations (sampled with replacement, or a
  shuffled permutation of 0..n-1 without replacement) and takes one gradient
  step per observation. The logged loss is the full-data sum of squared
  residuals, so logging costs O(nk) per update.

  Parameters
  ----------
  X : (n, p) design matrix (without an intercept column)
  y : (n,) response vector
  beta : (p+1,) initial coefficients (intercept first)
  step : learning rate for each update
  max_step : number of epochs (passes of n updates)
  seed : seed for the index-sampling RNG
  replace : sample observation indices with replacement if True

  Returns
  -------
  dict with the coefficient trajectory ("x"), full-data loss ("loss"),
  and update counter ("iter"), each including the starting state.
  """
  X = jnp.c_[jnp.ones(X.shape[0]), X]
  n, k = X.shape
  full_loss = lambda b: jnp.sum((y - X @ b)**2)
  # Gradient of a single observation's squared residual.
  point_grad = lambda b, i: 2*X[i,:] * (X[i,:]@b - y[i])

  rng = np.random.default_rng(seed)
  res = {"x": [beta], "loss": [full_loss(beta).item()], "iter": [0]}

  for epoch in range(max_step):
    if replace:
      order = rng.integers(0,n,n)
    else:
      order = np.array(range(n))
      rng.shuffle(order)

    for idx in order:
      beta = beta - step * point_grad(beta, idx)
      res["x"].append(beta)
      res["loss"].append(full_loss(beta).item())
      res["iter"].append(res["iter"][-1]+1)

  return res

Fitting

np.r_[lm.intercept_, lm.coef_]
array([ 3.0616, -0.0121, -0.0096,  0.096 ,  9.6955, 43.406 ,  0.0253,
        0.0284,  0.0962,  0.1069, 34.4884,  9.3445, -0.0165, -0.0147,
       -0.0396,  0.0969, -0.1057, -0.0943,  0.11  , -0.0096, -0.0875])
sgd_lm_rep = sto_grad_desc_lm(
  X, y, np.zeros(X.shape[1]+1), 
  step = 0.001, max_step=25, replace=True
)
sgd_lm_rep["x"][-1]
Array([ 3.0342, -0.0896, -0.0556,  0.045 ,  9.718 , 43.423 ,  0.0448,
        0.0839,  0.0512,  0.1082, 34.4395,  9.311 ,  0.0179,  0.035 ,
        0.0028,  0.1128, -0.0999, -0.123 ,  0.1248,  0.0134, -0.1609],      dtype=float64)
sgd_lm_worep = sto_grad_desc_lm(
  X, y, np.zeros(X.shape[1]+1), 
  step = 0.001, max_step=25, replace=False
)
sgd_lm_worep["x"][-1]
Array([ 3.065 , -0.0102, -0.0065,  0.0996,  9.7028, 43.4078,  0.0298,
        0.0349,  0.0952,  0.0933, 34.4662,  9.336 , -0.0252, -0.0115,
       -0.0462,  0.0774, -0.1103, -0.0746,  0.1096, -0.0269, -0.0995],      dtype=float64)

Using Epochs

Generally, rather than thinking in steps we use epochs instead - an epoch is one complete pass through the data.

A bigger example

from sklearn.datasets import make_regression
X, y, coef = make_regression(
  n_samples=10000, n_features=20, n_informative=4, 
  bias=3, noise=1, random_state=1234, coef=True
)
lm = LinearRegression().fit(X,y)
np.r_[lm.intercept_, lm.coef_]
array([ 3.0081,  0.0088,  0.0002,  0.0021,  0.0037,  0.0033,  0.026 ,
       -0.0006,  0.0005, 12.2771, 44.4939,  3.6423,  0.0168, 61.3938,
       -0.0012, -0.0056,  0.014 , -0.0093, -0.0056,  0.0024,  0.0217])

Fitting

gd_lm = grad_desc_lm(
  X, y, np.zeros(X.shape[1]+1),
  step = 0.00005, max_step=3
)
gd_lm["x"][-1]
Array([ 3.01  ,  0.0118,  0.0029,  0.0033, -0.0014,  0.0028,  0.0252,
       -0.0005,  0.0009, 12.2793, 44.4961,  3.6409,  0.0165, 61.3964,
        0.0011,  0.0005,  0.011 , -0.0134, -0.0045,  0.0028,  0.0244],      dtype=float64)
sgd_lm_rep = sto_grad_desc_lm(
  X, y, np.zeros(X.shape[1]+1), 
  step = 0.001, max_step=3, replace=True
)
sgd_lm_rep["x"][-1]
Array([ 2.9968,  0.0452,  0.0346,  0.0205, -0.0286, -0.0518, -0.0534,
       -0.0142,  0.0371, 12.2115, 44.5628,  3.6736,  0.0249, 61.4223,
       -0.0094, -0.0877,  0.0134, -0.0324, -0.0178,  0.0222,  0.0089],      dtype=float64)
sgd_lm_worep = sto_grad_desc_lm(
  X, y, np.zeros(X.shape[1]+1), 
  step = 0.001, max_step=3, replace=False
)
sgd_lm_worep["x"][-1]
Array([ 2.9756,  0.0149, -0.0011, -0.0047,  0.011 , -0.0176,  0.0251,
       -0.0244, -0.0328, 12.2937, 44.5281,  3.6333, -0.0158, 61.4162,
        0.0541,  0.0064, -0.0231, -0.0014, -0.0144, -0.0299,  0.0141],      dtype=float64)

Mini batch gradient descent

This is a further variant of stochastic gradient descent where a mini batch of \(m\) data points is selected for each gradient update,

  • The idea is to find a balance between the cost of increasing the data size vs the speed-up of vectorized calculations.

  • More updates per epoch than GD, but less than SGD

MBGD - Linear Regression

def mb_grad_desc_lm(X, y, beta, step, batch_size = 10, max_step=50, seed=1234, replace=True):
  """Mini-batch gradient descent for linear regression.

  An intercept column of ones is prepended to X; the logged loss is the
  full-data sum of squared residuals. Each epoch draws n observation indices
  (with replacement, or a shuffled permutation without) and consumes them in
  consecutive batches of `batch_size`.

  Parameters
  ----------
  X : (n, p) design matrix (without an intercept column)
  y : (n,) response vector
  beta : (p+1,) initial coefficients (intercept first)
  step : learning rate for each batch update
  batch_size : observations per gradient update; if n is not divisible by
    batch_size the final batch of each epoch is smaller
  max_step : number of epochs (full passes through the data)
  seed : seed for the index-sampling RNG
  replace : sample observation indices with replacement if True

  Returns
  -------
  dict with the coefficient trajectory ("x"), full-data loss ("loss"),
  and update counter ("iter"), each including the starting state.
  """
  X = jnp.c_[jnp.ones(X.shape[0]), X]
  f = lambda beta: jnp.sum((y - X @ beta)**2)
  # Batch gradient of the sum-of-squares loss (summed, not averaged, over the batch).
  grad = lambda beta, i: 2*X[i,:].T @ (X[i,:]@beta - y[i])
  n, k = X.shape

  res = {"x": [beta], "loss": [f(beta).item()], "iter": [0]}
  rng = np.random.default_rng(seed)

  for i in range(max_step):
    if replace:
      js = rng.integers(0,n,n)
    else:
      js = np.array(range(n))
      rng.shuffle(js)

    # array_split tolerates n % batch_size != 0 (a short final batch);
    # reshape(-1, batch_size) would raise in that case.
    for j in np.array_split(js, np.arange(batch_size, n, batch_size)):
      beta = beta - grad(beta, j) * step
      res["x"].append(beta)
      res["loss"].append(f(beta).item())
      res["iter"].append(res["iter"][-1]+1)

  return res

Fitting

lm = LinearRegression().fit(X,y)
np.r_[lm.intercept_, lm.coef_]
array([ 3.0081,  0.0088,  0.0002,  0.0021,  0.0037,  0.0033,  0.026 ,
       -0.0006,  0.0005, 12.2771, 44.4939,  3.6423,  0.0168, 61.3938,
       -0.0012, -0.0056,  0.014 , -0.0093, -0.0056,  0.0024,  0.0217])
sizes = [10,50,100]
mbgd = { size: mb_grad_desc_lm(
           X, y, np.zeros(X.shape[1]+1), batch_size=size,
           step = 0.001, max_step=3, replace=False
         )
         for size in sizes }
Batch size: 10
[ 2.9754  0.0154  0.0004 -0.0038  0.0118 -0.0171  0.0248 -0.0242 -0.0336
 12.2937 44.5285  3.6334 -0.0156 61.417   0.0546  0.0075 -0.023  -0.0004
 -0.0135 -0.031   0.0138]

Batch size: 50
[ 2.9761  0.0107 -0.001  -0.0029  0.0119 -0.0161  0.0238 -0.0243 -0.0374
 12.2943 44.5304  3.6337 -0.0172 61.4199  0.0557  0.0068 -0.0246 -0.0015
 -0.0134 -0.0326  0.0127]

Batch size: 100
[ 2.973   0.0094 -0.0001 -0.0038  0.0123 -0.0171  0.0223 -0.0263 -0.0416
 12.2923 44.5302  3.635  -0.0155 61.4176  0.0565  0.009  -0.0258 -0.002
 -0.014  -0.0353  0.0045]

Results

(0.0, 0.125)

A bit of theory

We’ve talked a bit about the computational side of things, but why do these approaches work at all?

In statistics and machine learning many of our problems have a form that looks like,

\[ \underset{\theta}{\text{arg min}} \; \ell(\boldsymbol{X}, \theta) = \underset{\theta}{\text{arg min}} \; \frac{1}{n} \sum_{i=1}^n \ell(\boldsymbol{X}_i, \theta) \]

which means that the gradient of the loss function is given by

\[ \nabla \ell(\boldsymbol{X}, \theta) = \frac{1}{n} \sum_{i=1}^n \nabla \ell(\boldsymbol{X}_i, \theta) \]

Since computing the full sum is expensive, we can approximate the gradient using a randomly sampled subset (batch) \(B\) of the observations, \[ \nabla \ell(\boldsymbol{X}, \theta) \approx \frac{1}{|B|} \sum_{i \in B} \nabla \ell(\boldsymbol{X}_i, \theta) \]

SGD estimator

Because we are sampling \(B\) randomly, our SGD and mini batch GD approximations are unbiased estimators of the full gradient,

\[ E\left[ \frac{1}{|B|} \sum_{i \in B} \nabla \ell(\boldsymbol{X}_i, \theta) \right] = \frac{1}{n} \sum_{i=1}^n \nabla \ell(\boldsymbol{X}_i, \theta) = \nabla \ell(\boldsymbol{X}, \theta) \]

Each update can be viewed as a noisy gradient descent step (gradient + zero mean noise).

  • The difference between mini batch and stochastic gradient descent is that by increasing the computation cost per step we are reducing the noise variance for that step

Limitations

As mentioned previously we need to be a bit careful with learning rates and convergence for both of these methods. So far, our approach has been naive and runs for a fixed number of epochs.

If we want to use a convergence criterion we need to keep the following in mind:

  • Let \(\theta^*\) be a global / local minimizer of our loss function \(\ell(\boldsymbol{X},\theta)\), then by definition \(\nabla \ell(\boldsymbol{X},\theta^*) = 0\)

  • The issue is that our gradient approximation, \[ \frac{1}{|B|} \sum_{i \in B} \nabla \ell(\boldsymbol{X}_i, \theta) \ne 0 \] as \(B\) is a subset of the data, therefore our algorithm will keep taking steps / never converge.

Solution

The practical solution to this is to implement a learning rate schedule which generally shrinks the learning rate / step size over time to ensure convergence.

The choice of the exact learning schedule is problem specific, and is usually about finding the balance of how quickly to shrink the step size.

Some common examples:

  • Piecewise constant - \(\eta_t = \eta_i \text{ if } t_i \leq t \leq t_{i+1}\)

  • Exponential decay - \(\eta_t = \eta_0 e^{-\lambda t}\)

  • Polynomial decay - \(\eta_t = \eta_0 (\beta t+1)^{-\alpha}\)

There are many more approaches including more exotic techniques that allow the learning rate to increase and decrease to help the optimizer better explore the objective function and in some cases escape local optima.

Adaptive updates & Momentum

AdaGrad

This approach was proposed by Duchi, Hazan, & Singer in 2011 and is based on the idea of scaling the learning rates for the current step by the sum of the squared gradients of previous steps - this has the effect of shrinking the step size of dimensions with large previous gradients.

\[ \begin{aligned} \boldsymbol{\theta}_{t+1} &= \boldsymbol{\theta}_t - \eta_t \frac{1}{\sqrt{\boldsymbol{s}_t + \epsilon}} \odot \nabla \ell(\boldsymbol{X}, \boldsymbol{\theta}_t)\\ \boldsymbol{s}_t &= \sum_{i=1}^t \left(\nabla \ell(\boldsymbol{X}, \boldsymbol{\theta}_i)\right)^2 \end{aligned} \]

where \(\epsilon\) is a small constant (i.e. \(10^{-8}\)) to avoid division by zero.

Implementation

def adagrad_lm(X, y, beta, step, batch_size = 10, max_step=50, seed=1234, replace=True, eps=1e-8):
  """Mini-batch AdaGrad for linear regression.

  Like mini-batch gradient descent, but each coordinate's step is scaled by
  the accumulated sum of its squared gradients S (across all updates so far):
  beta <- beta - step * G / sqrt(S + eps). Coordinates with historically
  large gradients therefore take smaller steps.

  Parameters
  ----------
  X : (n, p) design matrix (without an intercept column)
  y : (n,) response vector
  beta : (p+1,) initial coefficients (intercept first)
  step : base learning rate
  batch_size : observations per update; if n is not divisible by batch_size
    the final batch of each epoch is smaller
  max_step : number of epochs
  seed : seed for the index-sampling RNG
  replace : sample observation indices with replacement if True
  eps : small constant guarding against division by zero

  Returns
  -------
  dict with the coefficient trajectory ("x"), full-data loss ("loss"),
  and update counter ("iter"), each including the starting state.
  """
  X = jnp.c_[jnp.ones(X.shape[0]), X]
  f = lambda beta: jnp.sum((y - X @ beta)**2)
  grad = lambda beta, i: 2*X[i,:].T @ (X[i,:]@beta - y[i])
  n, k = X.shape

  res = {"x": [beta], "loss": [f(beta).item()], "iter": [0]}
  rng = np.random.default_rng(seed)

  # Running sum of squared gradients, persistent across epochs.
  S = np.zeros(k)

  for i in range(max_step):
    if replace:
      js = rng.integers(0,n,n)
    else:
      js = np.array(range(n))
      rng.shuffle(js)

    # array_split tolerates n % batch_size != 0 (a short final batch);
    # reshape(-1, batch_size) would raise in that case.
    for j in np.array_split(js, np.arange(batch_size, n, batch_size)):
      G = grad(beta, j)
      S += G**2

      beta = beta - step * (1/np.sqrt(S + eps)) * G

      res["x"].append(beta)
      res["loss"].append(f(beta).item())
      res["iter"].append(res["iter"][-1]+1)

  return res

A medium example

from sklearn.datasets import make_regression
X, y, coef = make_regression(
  n_samples=1000, n_features=20, n_informative=4, 
  bias=3, noise=1, random_state=1234, coef=True
)
lm = LinearRegression().fit(X,y)
np.r_[lm.intercept_, lm.coef_]
array([ 3.0034,  0.0144,  0.0304, -0.0306,  0.0334,  0.0292, -0.0214,
       30.9685,  0.0189, -0.005 ,  0.005 ,  0.0016, 40.5613, -0.0422,
       44.6158, -0.0495, 72.2522,  0.002 , -0.0336,  0.0148,  0.0516])

Fitting

sizes = [1, 25, 50, 1000]
lrs = [10] * 4
algos = ["AdaGrad - SGD", "AdaGrad - MBGD (25)", "AdaGrad - MBGD (50)", "AdaGrad - GD"]

adagrad = { size: adagrad_lm(
                    X, y, np.zeros(X.shape[1]+1), batch_size=size,
                    step = lr, max_step=15, replace=False
                  )
            for size, lr in zip(sizes,lrs) }
AdaGrad - SGD
[ 3.0218  0.0577  0.0304  0.0376 -0.0169 -0.0808 -0.0156 31.0294  0.0226
  0.0419 -0.0442  0.1085 40.7294  0.0879 44.675  -0.246  72.3294  0.0787
 -0.0684  0.0128  0.0556]

AdaGrad - MBGD (25)
[ 3.0464  0.1249  0.0086  0.0118 -0.0177 -0.0265 -0.0327 30.9548 -0.0094
  0.0146  0.0122  0.1039 40.6525  0.0591 44.65   -0.1758 72.27    0.0771
 -0.0531 -0.0165  0.0682]

AdaGrad - MBGD (50)
[ 3.0341  0.1253  0.0233 -0.0095  0.0094 -0.0057 -0.0136 30.9632 -0.0047
 -0.0028  0.0226  0.0697 40.6104  0.024  44.638  -0.1379 72.2525  0.0463
 -0.0467 -0.0267  0.0971]

AdaGrad - GD
[ 2.5491  0.1287 -1.0413  1.0849  0.2096  0.9691 -0.342  31.1497 -0.4613
  0.8328  0.0183  1.3411 38.2487 -0.4805 40.5215  0.522  49.9215 -0.6421
 -0.1804 -1.0691 -0.8411]

Results

RMSProp

With AdaGrad the denominator involving \(\boldsymbol{s}_t\) gets larger as \(t\) increases, but in some cases it gets too large too fast to effectively explore the loss function. An alternative is to use a moving average of the past squared gradients instead.

RMSProp replaces AdaGrad’s \(\boldsymbol{s}_t\) with the following,

\[ \boldsymbol{s}_t = \beta \, \boldsymbol{s}_{t-1} + (1-\beta) \, (\nabla \ell(\boldsymbol{X},\boldsymbol{\theta}_t))^2 \\ \boldsymbol{s}_0 = \boldsymbol{0} \]

in practice a value of \(\beta \approx 0.9\) is used.

Implementation

def rmsprop_lm(X, y, beta, step, batch_size = 10, max_step=50, seed=1234, replace=True, eps=1e-8, b=0.9):
  """Mini-batch RMSProp for linear regression.

  Like AdaGrad, but the squared-gradient accumulator is an exponentially
  weighted moving average, S <- b*S + (1-b)*G**2, so the effective step size
  does not shrink monotonically over time.

  Parameters
  ----------
  X : (n, p) design matrix (without an intercept column)
  y : (n,) response vector
  beta : (p+1,) initial coefficients (intercept first)
  step : base learning rate
  batch_size : observations per update; if n is not divisible by batch_size
    the final batch of each epoch is smaller
  max_step : number of epochs
  seed : seed for the index-sampling RNG
  replace : sample observation indices with replacement if True
  eps : small constant guarding against division by zero
  b : decay rate of the moving average (typically ~0.9)

  Returns
  -------
  dict with the coefficient trajectory ("x"), full-data loss ("loss"),
  and update counter ("iter"), each including the starting state.
  """
  X = jnp.c_[jnp.ones(X.shape[0]), X]
  f = lambda beta: jnp.sum((y - X @ beta)**2)
  grad = lambda beta, i: 2*X[i,:].T @ (X[i,:]@beta - y[i])
  n, k = X.shape

  res = {"x": [beta], "loss": [f(beta).item()], "iter": [0]}
  rng = np.random.default_rng(seed)

  # Moving average of squared gradients, persistent across epochs.
  S = np.zeros(k)

  for i in range(max_step):
    if replace:
      js = rng.integers(0,n,n)
    else:
      js = np.array(range(n))
      rng.shuffle(js)

    # array_split tolerates n % batch_size != 0 (a short final batch);
    # reshape(-1, batch_size) would raise in that case.
    for j in np.array_split(js, np.arange(batch_size, n, batch_size)):
      G = grad(beta, j)
      S = b*S + (1-b) * G**2

      beta = beta - step * (1/np.sqrt(S + eps)) * G

      res["x"].append(beta)
      res["loss"].append(f(beta).item())
      res["iter"].append(res["iter"][-1]+1)

  return res

Fitting

sizes = [1, 25, 50, 1000]
lrs = [0.01, 0.1, 0.25, 1]
algos = ["RMSProp - SGD", "RMSProp - MBGD (25)", "RMSProp - MBGD (50)", "RMSProp - GD"]

rmsprop = { size: rmsprop_lm(
                    X, y, np.zeros(X.shape[1]+1), batch_size=size,
                    step = lr, max_step=25, replace=False
                  )
            for size, lr in zip(sizes,lrs) }
RMSProp - SGD
[ 3.0721  0.0241 -0.0229 -0.0109  0.0477  0.0879 -0.019  30.968   0.063
 -0.0005  0.0926 -0.0024 40.5439 -0.0316 44.6041 -0.0575 72.2169 -0.0054
 -0.0461 -0.0527  0.0905]

RMSProp - MBGD (25)
[ 3.148   0.0316 -0.1395 -0.0606  0.0544  0.0715 -0.0488 30.9947  0.0788
 -0.0545 -0.0162  0.0012 40.4057 -0.0185 44.6314 -0.1133 72.2111 -0.0824
 -0.0153 -0.0569  0.117 ]

RMSProp - MBGD (50)
[ 3.3168 -0.0474 -0.2298 -0.0102  0.0364 -0.0025 -0.0444 31.0121  0.152
 -0.0367 -0.0114 -0.0227 40.485  -0.0842 44.7648 -0.2086 72.221  -0.0022
 -0.0477 -0.2108  0.4098]

RMSProp - GD
[ 1.6267 -0.159  -1.8509  1.9847  0.9328  2.265  -0.7434 26.0131 -0.579
  1.2681 -0.6497  2.8729 28.3096 -0.5988 28.8862  0.5487 30.9147 -0.9718
 -0.5039 -2.0638 -1.3314]

Results

Momentum

Rather than just using the gradient information at our current location it may be beneficial to use information from our previous steps as well. A general setup for this type of approach looks like,

\[ \boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta \, \boldsymbol{m}_t \\ \boldsymbol{m}_t = \beta \, \boldsymbol{m}_{t-1} + (1-\beta) \, \nabla \ell(\boldsymbol{X}, \boldsymbol{\theta}_t) \]

where \(\eta\) is our step size and \(\beta\) determines the weighting of the current gradient and the previous gradients.

If you have taken a course on time series, this has a flavor that looks a lot like moving average models,

\[ \boldsymbol{m}_t = (1-\beta) \, \nabla \ell(\boldsymbol{X}, \boldsymbol{\theta}_t) + \beta(1-\beta) \, \, \nabla \ell(\boldsymbol{X}, \boldsymbol{\theta}_{t-1}) + \beta^2(1-\beta) \, \, \nabla \ell(\boldsymbol{X}, \boldsymbol{\theta}_{t-2}) + \cdots \]

Adam

The “adaptive moment estimation” algorithm is a combination of momentum with RMSProp,

\[ \begin{aligned} \theta_{t+1} &= \theta_t - \eta_t \frac{\boldsymbol{m_t}}{\sqrt{\boldsymbol{s}_t + \epsilon}} \\ \boldsymbol{m}_t &= \beta_1 \, \boldsymbol{m}_{t-1} + (1-\beta_1) \, \nabla \ell(\boldsymbol{X}, \theta_t) \\ \boldsymbol{s}_t &= \beta_2 \, \boldsymbol{s}_{t-1} + (1-\beta_2) \, (\nabla \ell(\boldsymbol{X},\boldsymbol{\theta}_t))^2 \\ \end{aligned} \]

Note that RMSProp is a special case of Adam when \(\beta_1 = 0\).

Adam is widely used in practice and is commonly available within tools like Torch for fitting NN models.

In typical use \(\beta_1=0.9\), \(\beta_2=0.999\), \(\epsilon=10^{-6}\), and \(\eta_t=0.001\) are used. As the learning rate is not guaranteed to decrease over time, the algorithm is not guaranteed to converge.

Bias corrections

One small alteration that was suggested by the original authors and is commonly used is to correct for the bias towards small values in the initial estimates of \(\boldsymbol{m}_t\) and \(\boldsymbol{s}_t\). In which case they are replaced with,

\[ \begin{aligned} \hat{\boldsymbol{m}}_t &= \boldsymbol{m}_t / (1-{\beta_1}^t) \\ \hat{\boldsymbol{s}}_t &=\boldsymbol{s}_t / (1-{\beta_2}^t) \\ \end{aligned} \]

Implementation

def adam_lm(X, y, beta, step=0.001, batch_size = 10, max_step=50, seed=1234, replace=True, eps=1e-6, b1=0.9, b2=0.999):
  """Mini-batch Adam for linear regression.

  Combines momentum (moving average M of gradients) with RMSProp (moving
  average S of squared gradients), using bias-corrected estimates M_hat and
  S_hat for the update: beta <- beta - step * M_hat / sqrt(S_hat + eps).

  Parameters
  ----------
  X : (n, p) design matrix (without an intercept column)
  y : (n,) response vector
  beta : (p+1,) initial coefficients (intercept first)
  step : base learning rate
  batch_size : observations per update; if n is not divisible by batch_size
    the final batch of each epoch is smaller
  max_step : number of epochs
  seed : seed for the index-sampling RNG
  replace : sample observation indices with replacement if True
  eps : small constant guarding against division by zero
  b1 : decay rate for the gradient moving average (momentum)
  b2 : decay rate for the squared-gradient moving average

  Returns
  -------
  dict with the coefficient trajectory ("x"), full-data loss ("loss"),
  and update counter ("iter"), each including the starting state.
  """
  X = jnp.c_[jnp.ones(X.shape[0]), X]
  f = lambda beta: jnp.sum((y - X @ beta)**2)
  grad = lambda beta, i: 2*X[i,:].T @ (X[i,:]@beta - y[i])
  n, k = X.shape

  res = {"x": [beta], "loss": [f(beta).item()], "iter": [0]}
  rng = np.random.default_rng(seed)

  # Moment estimates and global update counter, persistent across epochs;
  # t drives the bias-correction factors (1 - b**t).
  S = np.zeros(k)
  M = np.zeros(k)
  t = 0

  for i in range(max_step):
    if replace:
      js = rng.integers(0,n,n)
    else:
      js = np.array(range(n))
      rng.shuffle(js)

    # array_split tolerates n % batch_size != 0 (a short final batch);
    # reshape(-1, batch_size) would raise in that case.
    for j in np.array_split(js, np.arange(batch_size, n, batch_size)):
      t += 1
      G = grad(beta, j)
      S = b2*S + (1-b2) * G**2
      M = b1*M + (1-b1) * G

      # Correct the bias toward zero in the early estimates of M and S.
      M_hat = M / (1-b1**t)
      S_hat = S / (1-b2**t)

      beta = beta - step * (M_hat / np.sqrt(S_hat + eps))

      res["x"].append(beta)
      res["loss"].append(f(beta).item())
      res["iter"].append(res["iter"][-1]+1)

  return res

Fitting

sizes = [1, 25, 50, 1000]
lrs = [0.01, 0.5, 0.75, 1]
algos = ["Adam - SGD", "Adam - MBGD (25)", "Adam - MBGD (50)", "Adam - GD"]

adam = { size: adam_lm(
                 X, y, np.zeros(X.shape[1]+1), batch_size=size,
                 step=lr, max_step=25, replace=False
               )
         for size, lr in zip(sizes,lrs) }
Adam - SGD
[ 3.0729  0.0375 -0.0207 -0.0162  0.1077  0.0935  0.0036 31.0195  0.0789
  0.0502  0.0577  0.0243 40.5269 -0.0207 44.5955 -0.0656 72.2082 -0.0416
 -0.0112 -0.0108  0.0732]

Adam - MBGD (25)
[ 2.9533 -0.0291  0.0937 -0.1201  0.0327  0.0753 -0.0276 30.9863  0.0151
  0.1048  0.046   0.0183 40.5607  0.012  44.5998 -0.0339 72.2525  0.0056
  0.0034 -0.0059  0.0159]

Adam - MBGD (50)
[ 2.9813 -0.0251  0.0356 -0.0671  0.0195  0.0527 -0.0149 30.9626  0.0383
  0.0202  0.0046 -0.0108 40.5586 -0.0162 44.6244 -0.0177 72.25    0.0053
 -0.0439 -0.0183  0.038 ]

Adam - GD
[ 1.6529 -0.045  -1.7455  1.8308  0.7478  2.5328 -0.3475 22.8207 -1.9976
  1.59   -0.1611  2.5457 23.488  -1.2908 23.6385  0.9278 24.1926 -0.3689
 -0.3436 -1.4462 -2.3327]

Results